''' For CNN training and testing. '''


import os
import timeit
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F


def denorm(x, means, stds):
    '''
    x: torch tensor
    means: means for normalization
    stds: stds for normalization
    '''
    x_ch0 = torch.unsqueeze(x[:, 0], 1) * (stds[0] / 0.5) + (means[0] - 0.5) / 0.5
    x_ch1 = torch.unsqueeze(x[:, 1], 1) * (stds[1] / 0.5) + (means[1] - 0.5) / 0.5
    x_ch2 = torch.unsqueeze(x[:, 2], 1) * (stds[2] / 0.5) + (means[2] - 0.5) / 0.5
    x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x


''' function for cnn training '''
def train_cnn(net, net_name, trainloader, testloader, epochs, resume_epoch=0, save_freq=[100, 150], lr_base=0.1, lr_decay_factor=0.1, lr_decay_epochs=[150, 250], weight_decay=1e-4, extract_feature=False, net_decoder=None, lambda_reconst=0, train_means=None, train_stds=None, path_to_ckpt = None):

    ''' learning rate decay '''
    def adjust_learning_rate(optimizer, epoch):
        """decrease the learning rate """
        lr = lr_base

        num_decays = len(lr_decay_epochs)
        for decay_i in range(num_decays):
            if epoch >= lr_decay_epochs[decay_i]:
                lr = lr * lr_decay_factor
            #end if epoch
        #end for decay_i
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


    net = net.cuda()
    criterion = nn.CrossEntropyLoss()
    params = list(net.parameters())
    
    if lambda_reconst>0 and net_decoder is not None:
        net_decoder = net_decoder.cuda()
        criterion_reconst = nn.MSELoss()
        params += list(net_decoder.parameters())

    optimizer = torch.optim.SGD(params, lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
    # optimizer = torch.optim.Adam(params, lr = lr_base, betas=(0, 0.999), weight_decay=weight_decay)

    if path_to_ckpt is not None and resume_epoch>0:
        save_file = path_to_ckpt + "/{}_checkpoint_epoch_{}.pth".format(net_name, resume_epoch)
        if lambda_reconst>0 and net_decoder is not None:
            checkpoint = torch.load(save_file)
            net.load_state_dict(checkpoint['net_state_dict'])
            net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            torch.set_rng_state(checkpoint['rng_state'])
        else:
            checkpoint = torch.load(save_file)
            net.load_state_dict(checkpoint['net_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            torch.set_rng_state(checkpoint['rng_state'])
    #end if


    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        net.train()
        train_loss = 0
        train_loss_const = 0
        adjust_learning_rate(optimizer, epoch)
        for batch_idx, (batch_train_images, batch_train_labels) in enumerate(trainloader):

            batch_train_images = batch_train_images.type(torch.float).cuda()
            batch_train_labels = batch_train_labels.type(torch.long).cuda()

            #Forward pass
            if lambda_reconst>0 and net_decoder is not None and extract_feature:
                outputs, features = net(batch_train_images)
                batch_reconst_images = net_decoder(features)
                class_loss = criterion(outputs, batch_train_labels)
                if train_means and train_stds:
                    batch_train_images = denorm(batch_train_images, train_means, train_stds) #decoder use different normalization constants
                reconst_loss = criterion_reconst(batch_reconst_images, batch_train_images)
                loss = class_loss + lambda_reconst * reconst_loss
            else:
                if extract_feature:
                    outputs, _ = net(batch_train_images)
                else:
                    outputs = net(batch_train_images)
                ## standard CE loss
                loss = criterion(outputs, batch_train_labels)

            #backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
            if lambda_reconst>0 and net_decoder is not None and extract_feature:
                train_loss_const += reconst_loss.cpu().item()
        #end for batch_idx
        test_acc = test_cnn(net, testloader, extract_feature, verbose=False)

        if lambda_reconst>0 and net_decoder is not None and extract_feature:
            print("{}, lambda:{:.3f}: [epoch {}/{}] train_loss:{:.3f}, reconst_loss:{:.3f}, test_acc:{:.3f} Time:{:.4f}".format(net_name, lambda_reconst, epoch+1, epochs, train_loss/(batch_idx+1), train_loss_const/(batch_idx+1), test_acc, timeit.default_timer()-start_time))
        else:
            print('%s: [epoch %d/%d] train_loss:%.3f, test_acc:%.3f Time: %.4f' % (net_name, epoch+1, epochs, train_loss/(batch_idx+1), test_acc, timeit.default_timer()-start_time))

        # save checkpoint
        if path_to_ckpt is not None and ((epoch+1) in save_freq or (epoch+1) == epochs) :
            save_file = path_to_ckpt + "/{}_checkpoint_epoch_{}.pth".format(net_name, epoch+1)
            if lambda_reconst>0 and net_decoder is not None:
                torch.save({
                        'net_state_dict': net.state_dict(),
                        'net_decoder_state_dict': net_decoder.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'rng_state': torch.get_rng_state()
                }, save_file)
            else:
                torch.save({
                        'net_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'rng_state': torch.get_rng_state()
                }, save_file)
    #end for epoch

    return net



def test_cnn(net, testloader, extract_feature=False, verbose=False):

    net = net.cuda()
    net.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for batch_idx, (images, labels) in enumerate(testloader):
            images = images.type(torch.float).cuda()
            labels = labels.type(torch.long).cuda()
            if extract_feature:
                outputs,_ = net(images)
            else:
                outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        if verbose:
            print('Test Accuracy of the model on the 10000 test images: {} %'.format(100.0 * correct / total))
    return 100.0 * correct / total
